"""
#
# Flow stability for dynamic community detection https://arxiv.org/abs/2101.06131v2
#
# Copyright (C) 2021 Alexandre Bovet <alexandre.bovet@maths.ox.ac.uk>
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 3 of the License, or (at your option) any
# later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.


Computes and plots cluster statistics for the mice dataset (Fig. 7).


"""


import sys
import os
# Make the parent of the directory containing this script importable.
PACKAGE_PARENT = '..'

# Resolve this file's location to an absolute, symlink-free path first.
_script_path = os.path.join(os.getcwd(), os.path.expanduser(__file__))
SCRIPT_DIR = os.path.dirname(os.path.realpath(_script_path))

sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))


import numpy as np
import pandas as pd


import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

# Palette shared by all plots in this script; entry 0 is white.
color_list = [
    "#ffffff",
    "#4ba706",
    "#a2007e",
    "#806dcb",
    "#5eb275",
    "#ca3b01",
    "#01a4d6",
    "#b77600",
    "#a39643",
    "#cc6ea9",
    "#1e5e39",
    "#cb5b5a",
]

# Colormap over the full palette, and a variant without the leading white.
cmap = ListedColormap(color_list)
cmap_now = ListedColormap(color_list[1:])

# Apply the 'alex_paper' style sheet (not a style bundled with matplotlib;
# presumably installed in the user's matplotlib config -- TODO confirm),
# then override the line width on top of it. The order matters: the style
# sheet must be applied before the rcParams override.
plt.style.use('alex_paper')

import matplotlib as mpl
mpl.rcParams['lines.linewidth'] = 1.7


# Data and figure locations, given relative to the current working directory.
datadir = '../paper_data/mice_data_march_april2017/'
figdir = '../figures/micenet_march_april2017'

# Paths derived from the data directory.
clusterdir = os.path.join(datadir, 'clustersI')
net_file = os.path.join(datadir, 'micenet_2017_02_28_to_2017_05_01')

# NOTE(review): unconditional stop. Presumably a guard so that the file is
# executed cell by cell (the `#%%` markers) instead of being run as a whole
# script -- confirm before removing.
raise Exception()

#%%


# Load the pickled multi-week partition results from the data directory.
_partitions_path = os.path.join(datadir, 'multi_weeks_partitions.pickle')
multi_res = pd.read_pickle(_partitions_path)

# The 'sex_dict' entry is stored next to the per-tau_w results; take it out
# so that only the other entries remain in multi_res.
sex_dict = multi_res.pop('sex_dict')

# Sorted list of the float keys of multi_res (the tau_w values).
tau_ws = sorted(filter(lambda key: isinstance(key, float), multi_res.keys()))


#%% avg group size per week


# Keys of multi_res (tau_w values) shown in the figure. Each series gets a
# small horizontal offset so that the error bars of the two do not overlap.
taus_to_plot = [1.0, 86400.0]
x_shift = [-0.1, 0.1]
labels = [r'$\tau_w$ = 1 s', r'$\tau_w$ = 24 h']

fig, (ax, ax2, ax3) = plt.subplots(3, 1, figsize=(6.4, 9), sharex=True)


def _weekly_mean_std(tau_w, values_of):
    """Return (means, stds) of a cluster statistic, one entry per partition pair.

    The entries of multi_res[tau_w]['partitions_forw'] and
    multi_res[tau_w]['partitions_back'] are paired up; for every pair the two
    `cluster_list` attributes are concatenated and passed to `values_of`,
    which must return a list of numbers. The mean and the standard deviation
    of that list are collected in two parallel lists.
    """
    means, stds = [], []
    for p, q in zip(multi_res[tau_w]['partitions_forw'],
                    multi_res[tau_w]['partitions_back']):
        values = values_of(p.cluster_list + q.cluster_list)
        means.append(np.mean(values))
        stds.append(np.std(values))
    return means, stds


def _group_sizes(clusters):
    """Sizes of the clusters that have more than one member."""
    return [len(c) for c in clusters if len(c) > 1]


def _female_counts(clusters):
    """Female counts per cluster, keeping only counts larger than one."""
    females = sex_dict['female']
    counts = [len(c.intersection(females)) for c in clusters]
    return [n for n in counts if n > 1]


def _female_fractions(clusters):
    """Fraction of females in each cluster that has more than one member."""
    females = sex_dict['female']
    return [len(c.intersection(females)) / len(c)
            for c in clusters if len(c) > 1]


# One panel per statistic: group size, number of females, female proportion.
for _panel, _values_of in ((ax, _group_sizes),
                           (ax2, _female_counts),
                           (ax3, _female_fractions)):
    for i, tau_w in enumerate(taus_to_plot):
        avg_g_size, std_g_size = _weekly_mean_std(tau_w, _values_of)

        # x positions start at 1 (the axis is labelled 'week').
        _panel.errorbar(np.arange(len(avg_g_size)) + x_shift[i] + 1,
                        avg_g_size,
                        yerr=std_g_size,
                        fmt='o-',
                        label=labels[i],
                        color=color_list[4 + i])

# Only the top panel carries the legend.
ax.legend()

ax.set_ylim([0, 95])
ax2.set_ylim([0, 95])
ax3.set_ylim([0, 1])

ax.set_ylabel('Group size')
ax2.set_ylabel('Num. females per group')
ax3.set_ylabel('Proportion of females')
ax3.set_xlabel('week')
ax3.set_xticks(np.arange(1, 10))
ax3.set_yticks(np.linspace(0, 1, 6))

# Panel letters. The Text property is spelled `usetex` (all lower case);
# the mixed-case `useTex` is rejected by current matplotlib versions.
ax.text(-0.2, 1.0, r'\textbf{A}', transform=ax.transAxes, usetex=True)
ax2.text(-0.2, 1.0, r'\textbf{B}', transform=ax2.transAxes, usetex=True)
ax3.text(-0.2, 1.0, r'\textbf{C}', transform=ax3.transAxes, usetex=True)

fig.align_ylabels([ax, ax2, ax3])


#%%

# Save through the figure object rather than pyplot's implicit "current
# figure", so the right figure is written even when this cell is run on its
# own and the current figure has changed or been closed.
fig.savefig(os.path.join(figdir, 'mice_weekly_group_stats.pdf'))
fig.savefig(os.path.join(figdir, 'mice_weekly_group_stats.png'), dpi=600)

#%% figure with also NVI


from matplotlib.gridspec import GridSpec

# A single 6x2 grid: the left column holds two panels of three rows each,
# the right column three panels of two rows each.
fig = plt.figure(figsize=(10.3, 7.3))
gs = GridSpec(6, 2, figure=fig)

ax0 = fig.add_subplot(gs[0:3, 0:1])
ax1 = fig.add_subplot(gs[3:6, 0:1], sharex=ax0)

ax2 = fig.add_subplot(gs[0:2, 1:2])
ax3 = fig.add_subplot(gs[2:4, 1:2], sharex=ax2)
ax4 = fig.add_subplot(gs[4:6, 1:2], sharex=ax2)

# --- left column: quantities plotted against tau_w ---------------------------

nvi_data = pd.read_pickle(os.path.join(datadir, 'mice_nvi_plot_data.pickle'))

ax0.errorbar(nvi_data['taus_to_plot'], nvi_data['avg_NVI'],
             yerr=nvi_data['std_NVI'], fmt='o-', color=color_list[2])

ax1.errorbar(nvi_data['taus_to_plot'], nvi_data['avg_csize'],
             yerr=nvi_data['std_csize'], fmt='o-', color=color_list[2])

# ax1 shares its x axis with ax0, so the log scale applies to both panels.
ax0.set_xscale('log')

ax1.set_xlabel(r'$\tau_w$ [s]')
ax0.set_ylabel('Norm. Var. Inf.')
ax1.set_ylabel('avg. group size')

ax0.axhline(y=0, ls='--', color='k', linewidth=0.5)
ax0.set_yticks([0, 0.0025, 0.005, 0.0075])

ax1.set_ylim((0, 53))
ax1.set_yticks([0, 10, 20, 30, 40, 50])

# --- right column: weekly statistics for two values of tau_w -----------------

# Keys of multi_res (tau_w values) shown in the figure. Each series gets a
# small horizontal offset so that the error bars of the two do not overlap.
taus_to_plot = [1.0, 86400.0]
x_shift = [-0.1, 0.1]
labels = [r'$\tau_w$ = 1 s', r'$\tau_w$ = 24 h']


def _weekly_mean_std(tau_w, values_of):
    """Return (means, stds) of a cluster statistic, one entry per partition pair.

    The entries of multi_res[tau_w]['partitions_forw'] and
    multi_res[tau_w]['partitions_back'] are paired up; for every pair the two
    `cluster_list` attributes are concatenated and passed to `values_of`,
    which must return a list of numbers. The mean and the standard deviation
    of that list are collected in two parallel lists.
    """
    means, stds = [], []
    for p, q in zip(multi_res[tau_w]['partitions_forw'],
                    multi_res[tau_w]['partitions_back']):
        values = values_of(p.cluster_list + q.cluster_list)
        means.append(np.mean(values))
        stds.append(np.std(values))
    return means, stds


def _group_sizes(clusters):
    """Sizes of the clusters that have more than one member."""
    return [len(c) for c in clusters if len(c) > 1]


def _female_counts(clusters):
    """Female counts per cluster, keeping only counts larger than one."""
    females = sex_dict['female']
    counts = [len(c.intersection(females)) for c in clusters]
    return [n for n in counts if n > 1]


def _female_fractions(clusters):
    """Fraction of females in each cluster that has more than one member."""
    females = sex_dict['female']
    return [len(c.intersection(females)) / len(c)
            for c in clusters if len(c) > 1]


# One panel per statistic: group size, number of females, female proportion.
for _panel, _values_of in ((ax2, _group_sizes),
                           (ax3, _female_counts),
                           (ax4, _female_fractions)):
    for i, tau_w in enumerate(taus_to_plot):
        avg_g_size, std_g_size = _weekly_mean_std(tau_w, _values_of)

        # x positions start at 1 (the axis is labelled 'week').
        _panel.errorbar(np.arange(len(avg_g_size)) + x_shift[i] + 1,
                        avg_g_size,
                        yerr=std_g_size,
                        fmt='o-',
                        label=labels[i],
                        color=color_list[4 + i])

# Only the top-right panel carries the legend.
ax2.legend()

ax2.set_ylim([0, 95])
ax3.set_ylim([0, 95])
ax4.set_ylim([0, 1])

ax2.set_ylabel('avg. group size')
ax2.set_yticks([0, 25, 50, 75])

ax3.set_ylabel('num. females\nper group')
ax3.set_yticks([0, 25, 50, 75])

ax4.set_ylabel('proportion\nof females')
ax4.set_xlabel('week')
ax4.set_xticks(np.arange(1, 10))
ax4.set_yticks(np.linspace(0, 1, 6))

# Panel letters. The Text property is spelled `usetex` (all lower case);
# the mixed-case `useTex` is rejected by current matplotlib versions.
ax0.text(-0.2, 1.0, r'\textbf{A}', transform=ax0.transAxes, usetex=True)
ax1.text(-0.2, 1.0, r'\textbf{B}', transform=ax1.transAxes, usetex=True)
ax2.text(-0.25, 1.0, r'\textbf{C}', transform=ax2.transAxes, usetex=True)
ax3.text(-0.25, 1.0, r'\textbf{D}', transform=ax3.transAxes, usetex=True)
ax4.text(-0.25, 1.0, r'\textbf{E}', transform=ax4.transAxes, usetex=True)

fig.align_ylabels([ax2, ax3, ax4])

#%%

# Save through the figure object rather than pyplot's implicit "current
# figure", so the right figure is written even when this cell is run on its
# own and the current figure has changed or been closed.
fig.savefig(os.path.join(figdir, 'fig_mice_NVI_and_weekly_group_stats.pdf'))
fig.savefig(os.path.join(figdir, 'fig_mice_NVI_and_weekly_group_stats.png'), dpi=600)


